import yaml
from train import train_model
from dataloader.dataloader import get_dataloader  # 你自己的数据加载函数

def main(config_path='config.yaml'):
    """Load the YAML config, build train/val data loaders, and run training.

    Args:
        config_path: Path to the YAML configuration file. A relative path is
            resolved against the current working directory. The default
            keeps the original hard-coded ``'config.yaml'`` behaviour.

    Raises:
        OSError: If ``config_path`` cannot be opened (e.g. FileNotFoundError).
        ValueError: If the file is empty or its top-level YAML node is not a
            mapping. ``yaml.safe_load`` returns ``None`` for an empty
            document, which would otherwise be handed on as ``config`` and
            fail later with a less obvious error.
    """
    # Read the configuration. safe_load is used so the file cannot
    # instantiate arbitrary Python objects.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {config_path!r}, "
            f"got {type(config).__name__}"
        )

    # Build the data loaders; get_dataloader returns a (loader, dataset) pair.
    train_loader, train_dataset = get_dataloader(config, mode="train")
    val_loader, val_dataset = get_dataloader(config, mode="val")

    # Run training. Per the original author's note, the model, optimizer and
    # loss function are all constructed inside train_model.
    train_model(
        train_loader=train_loader,
        train_dataset=train_dataset,
        val_loader=val_loader,
        val_dataset=val_dataset,
        config=config,
    )

if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()
